from __future__ import annotations

import asyncio
import time
import types
import pytest
from fastapi.testclient import TestClient

from tldw_Server_API.app.main import app
from tldw_Server_API.app.core.DB_Management.Workflows_DB import WorkflowsDatabase
from tldw_Server_API.app.api.v1.endpoints import workflows as wf_mod
from tldw_Server_API.app.core.AuthNZ.User_DB_Handling import User, get_request_user


pytestmark = pytest.mark.integration


@pytest.fixture()
def client_with_wf(tmp_path):
    """Yield a ``TestClient`` whose workflow endpoints use an isolated database.

    A fresh ``WorkflowsDatabase`` is created at ``tmp_path / "wf.db"`` and the
    app's ``get_request_user`` and ``wf_mod._get_db`` dependencies are
    overridden so requests run as a fixed active admin user against that
    database. All dependency overrides are cleared on teardown.
    """
    db = WorkflowsDatabase(str(tmp_path / "wf.db"))

    async def override_user():
        return User(id=1, username="tester", email="t@e.com", is_active=True, is_admin=True)

    def override_db():
        return db

    app.dependency_overrides[get_request_user] = override_user
    app.dependency_overrides[wf_mod._get_db] = override_db

    try:
        with TestClient(app) as client:
            yield client
    finally:
        # Also runs when entering/leaving the TestClient context raises, so the
        # overrides installed above can never leak into subsequent tests.
        app.dependency_overrides.clear()


def _wait_terminal(client: TestClient, run_id: str, timeout_s: float = 5.0):
    """Poll a workflow run until it reports a terminal status or time runs out.

    Args:
        client: Client used to GET ``/api/v1/workflows/runs/{run_id}``.
        run_id: Identifier of the run to poll.
        timeout_s: Maximum number of seconds to keep polling.

    Returns:
        The decoded JSON body of the run. If the run reached ``succeeded``,
        ``failed`` or ``cancelled`` before the deadline, that body is returned;
        otherwise the body of one final fetch made after the deadline.
    """
    run_url = f"/api/v1/workflows/runs/{run_id}"
    # Use the monotonic clock so wall-clock adjustments cannot shorten or
    # stretch the polling window.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        d = client.get(run_url).json()
        if d["status"] in ("succeeded", "failed", "cancelled"):
            return d
        time.sleep(0.02)
    # Timed out: hand back the latest state so the caller's assertion reports it.
    return client.get(run_url).json()


def test_cancel_during_delay_step(client_with_wf: TestClient):
    """Cancel a run right after starting a 1.5 s delay step.

    Expects the cancel request to return 200, the run to end up ``cancelled``,
    and the run's event list to contain a ``run_cancelled`` event.
    """
    api = client_with_wf

    # One delay step, long enough for the cancel request to be issued in between.
    delay_step = {"id": "d1", "type": "delay", "config": {"milliseconds": 1500}}
    workflow = {"name": "cancel-delay", "version": 1, "steps": [delay_step]}

    create_resp = api.post("/api/v1/workflows", json=workflow)
    workflow_id = create_resp.json()["id"]
    start_resp = api.post(f"/api/v1/workflows/{workflow_id}/run", json={"inputs": {}})
    run_id = start_resp.json()["run_id"]

    # Issue the cancel immediately after the run was started.
    cancel_resp = api.post(f"/api/v1/workflows/runs/{run_id}/cancel")
    assert cancel_resp.status_code == 200

    final_state = _wait_terminal(api, run_id)
    assert final_state["status"] == "cancelled"

    # The event log must record the cancellation.
    events = api.get(f"/api/v1/workflows/runs/{run_id}/events").json()
    has_cancel_event = any(event.get("event_type") == "run_cancelled" for event in events)
    assert has_cancel_event


def test_retry_backoff_persists_attempts(client_with_wf: TestClient):
    """Check that a timed-out step with one retry records an attempt of at least 2.

    The run is started with ``mode=async`` and the ``workflow_step_runs`` table
    is polled directly for up to 3 seconds; the run does not need to reach a
    terminal state.
    """
    api = client_with_wf

    # The step's 0.02 s timeout is far below its 100 ms simulated delay; one retry allowed.
    step = {
        "id": "s1",
        "type": "prompt",
        "retry": 1,
        "timeout_seconds": 0.02,
        "config": {"template": "Y", "simulate_delay_ms": 100},
    }
    workflow = {"name": "retry-persist", "version": 1, "steps": [step]}

    workflow_id = api.post("/api/v1/workflows", json=workflow).json()["id"]
    start_resp = api.post(f"/api/v1/workflows/{workflow_id}/run?mode=async", json={"inputs": {}})
    run_id = start_resp.json()["run_id"]

    # Reach the same database instance the app uses through the dependency override.
    db: WorkflowsDatabase = app.dependency_overrides[wf_mod._get_db]()
    latest_attempt_sql = (
        "SELECT attempt FROM workflow_step_runs WHERE run_id = ? ORDER BY started_at DESC LIMIT 1"
    )

    attempt = 0
    deadline = time.time() + 3.0
    while time.time() < deadline:
        latest = db._conn.cursor().execute(latest_attempt_sql, (run_id,)).fetchone()
        if latest:
            attempt = int(latest[0])
        # Initial try plus one retry means the newest row should report >= 2.
        if attempt >= 2:
            break
        time.sleep(0.02)

    assert attempt >= 2


def test_backoff_cap_env_applied(monkeypatch, client_with_wf: TestClient):
    """Check that the first intercepted sleep respects the configured backoff cap.

    ``WORKFLOWS_BACKOFF_CAP_SECONDS`` is set to ``1`` and ``asyncio.sleep`` is
    replaced by a recorder that stores each requested duration without
    actually waiting. A step that times out with one retry is then started
    with ``mode=async`` and the first recorded duration is checked against a
    2.0 s sanity bound.
    """
    client = client_with_wf
    # Cap the backoff to 1s and intercept asyncio.sleep to capture durations
    monkeypatch.setenv("WORKFLOWS_BACKOFF_CAP_SECONDS", "1")

    sleep_calls = []
    orig_sleep = asyncio.sleep

    async def _fake_sleep(delay, result=None):
        # Mirror asyncio.sleep's (delay, result=None) signature so any caller
        # passing ``result`` or using keywords keeps working while patched.
        sleep_calls.append(float(delay))
        # Still yield to the event loop, but without waiting.
        return await orig_sleep(0, result)

    monkeypatch.setattr("asyncio.sleep", _fake_sleep)

    # Force a timeout to trigger retry backoff in the engine
    definition = {
        "name": "retry-cap",
        "version": 1,
        "steps": [
            {"id": "s1", "type": "prompt", "retry": 1, "timeout_seconds": 0.01, "config": {"template": "Y", "simulate_delay_ms": 100}},
        ],
    }
    wid = client.post("/api/v1/workflows", json=definition).json()["id"]
    # Indexing "run_id" also verifies that the run was accepted.
    run_id = client.post(f"/api/v1/workflows/{wid}/run?mode=async", json={"inputs": {}}).json()["run_id"]

    # Wait up to 1s (monotonic, immune to wall-clock jumps) for the first recorded sleep
    deadline = time.monotonic() + 1.0
    while not sleep_calls and time.monotonic() < deadline:
        time.sleep(0.02)

    assert len(sleep_calls) >= 1
    # The first backoff should be <= cap(1) + jitter(~<=0.75). Assert sanity bound 2.0s
    assert sleep_calls[0] <= 2.0
